{% import 'macro.jinja2' as MRO with context %}

# Compute the confusion matrix from y_test and y_pred, then render it as an
# annotated heat map (rows are labelled "Actual", columns "Predicted").
confusion_matrix_data = metrics.confusion_matrix(y_test, y_pred)
n_rows, n_cols = confusion_matrix_data.shape

# Global font settings applied through matplotlib's rc mechanism.
font_settings = {'family': 'Times New Roman', 'size': '16'}
plt.rc('font', **font_settings)

fig, ax = plt.subplots(figsize=(6, 6))
im = ax.imshow(
    confusion_matrix_data,
    interpolation='nearest',
    cmap=plt.cm.Blues,
)
# No colorbar is drawn; `ax.figure.colorbar(im, ax=ax)` would add one.

# Major ticks: one per class, labelled with the class names.
ax.set_xticks(np.arange(n_cols))
ax.set_yticks(np.arange(n_rows))
ax.set_xticklabels(classes)
ax.set_yticklabels(classes)
ax.set_title('Confusion Matrix')
ax.set_ylabel('Actual')
ax.set_xlabel('Predicted')

# Minor ticks sit on the cell boundaries (half a unit off the cell centres)
# so the minor grid outlines each cell; their tick marks are hidden.
ax.set_xticks(np.arange(n_cols + 1) - 0.5, minor=True)
ax.set_yticks(np.arange(n_rows + 1) - 0.5, minor=True)
ax.grid(which="minor", color="gray", linestyle='-', linewidth=0.2)
ax.tick_params(which="minor", bottom=False, left=False)

# Slant the x tick labels, anchoring the rotation at the label's right end.
for tick_label in ax.get_xticklabels():
    tick_label.set(rotation=45, ha="right", rotation_mode="anchor")

# Cells above half of the largest entry get white text, the rest black.
thresh = confusion_matrix_data.max() / 2.0
for i, j in np.ndindex(n_rows, n_cols):
    cell = confusion_matrix_data[i, j]
    # Skip cells whose value rounds to zero at two decimal places.
    if int(cell * 100 + 0.5) <= 0:
        continue
    if cell > thresh:
        text_color = "white"
    else:
        text_color = "black"
    ax.text(j, i, cell, ha="center", va="center", color=text_color)

fig.tight_layout()
plt.show()
